from emodel_generalisation.mcmc import load_chains, plot_corner
import matplotlib as mpl
import numpy as np
from emodel_generalisation.utils import get_feature_df
from bluepyparallel import init_parallel_factory
import yaml
from emodel_generalisation.model.modifiers import synth_soma, synth_axon
from functools import partial
from emodel_generalisation.model.access_point import AccessPoint
from datareuse import Reuse
from emodel_generalisation.model.evaluation import feature_evaluation
from emodel_generalisation.mcmc import save_selected_emodels
from itertools import cycle
import json
from matplotlib.backends.backend_pdf import PdfPages
import pandas as pd
from pathlib import Path
from bluepyopt.ephys.responses import TimeVoltageResponse
import matplotlib.pyplot as plt
import pickle
from emodel_generalisation.utils import get_combo_hash
import extra_features


def plot_traces(
    trace_df, trace_path="traces", pdf_filename="traces.pdf", xlim=(4500, 14000)
):
    """Plot voltage traces of each row of ``trace_df`` into a multi-page PDF.

    One pyplot figure is used per response name found in the pickled traces.
    Responses with the same name from different rows are overlaid on the same
    figure. Rows whose ``trace_highlight`` value is truthy are drawn in a
    distinct colour taken from a cycle that starts with red. The other rows are
    drawn in black. If the ``trace_highlight`` column is missing, every row is
    highlighted.

    Only responses that are ``TimeVoltageResponse`` instances are plotted.

    Note: every pyplot figure open when the PDF is written, including figures
    not created here, is saved to the PDF and then closed.

    Args:
        trace_df (DataFrame): contains list of combos with traces to plot.
            If it has a ``trace_data`` column, that column holds the path of
            each row's pickled traces. Otherwise the file is
            ``<trace_path>/trace_id_<get_combo_hash(row)>.pkl``.
        trace_path (str): folder with the ``.pkl`` traces (only used when
            ``trace_df`` has no ``trace_data`` column)
        pdf_filename (str): name of pdf to save
        xlim (tuple): x-axis limits applied to every trace figure
    """
    colors = cycle(["r"] + [f"C{i}" for i in range(10)])
    trace_df = trace_df.copy()  # work on a copy: avoids pandas warnings and caller mutation
    if "trace_highlight" not in trace_df.columns:
        trace_df["trace_highlight"] = True

    for index in trace_df.index:
        row = trace_df.loc[index]
        # Highlighted rows each get their own colour; the others stay black.
        color = next(colors) if row["trace_highlight"] else "k"

        # Resolve the file into a fresh local. Re-binding ``trace_path`` here
        # would make every following row resolve under the previous file path.
        if "trace_data" in trace_df.columns:
            trace_file = row["trace_data"]
        else:
            trace_file = Path(trace_path) / f"trace_id_{get_combo_hash(row)}.pkl"

        # NOTE(review): pickle can execute arbitrary code; only load trace files
        # produced by this pipeline.
        with open(trace_file, "rb") as f:
            trace = pickle.load(f)
        if isinstance(trace, list):
            trace = trace[1]  # newer versions store the responses at index 1

        for protocol, response in trace.items():
            if isinstance(response, TimeVoltageResponse):
                # plt.figure(label) reuses the figure if it already exists, which
                # is what overlays the same response across rows.
                ax = plt.figure(protocol, figsize=(30, 7)).gca()
                ax.set_xlim(*xlim)
                ax.plot(response["time"], response["voltage"], c=color, lw=1)

    with PdfPages(pdf_filename) as pdf:
        for fig_id in plt.get_fignums():
            fig = plt.figure(fig_id)
            fig.suptitle(fig.get_label())
            pdf.savefig(fig)
            plt.close(fig)


if __name__ == "__main__":
    # Evaluate the "simplest_0" emodel plus a scan of "simplest_split" variants in
    # which the it2 conductance is divided between an "it2_low" copy (shifted by
    # +shift) and an "it2_high" copy (shifted by -shift). The plain it2 conductance
    # is set to zero. Traces are saved per row to PDFs, and burst number vs
    # holding voltage is plotted for all variants.
    parallel_factory = init_parallel_factory("multiprocessing")
    # parallel_factory = init_parallel_factory("dask_dataframe")
    name = "runaway"
    final_path = f"../final_{name}.json"
    access_point = AccessPoint(
        emodel_dir=".",
        recipes_path="config/recipes.json",
        final_path=final_path,
        with_seeds=True,
    )
    # Context managers so the config files are closed right after reading.
    with open("../../../mcmc_run/exemplar_data.yaml") as f:
        exemplar_data = yaml.safe_load(f)
    access_point.morph_path = exemplar_data["paths"]["all"]
    access_point.settings["morph_modifiers"] = [
        partial(synth_soma, params=exemplar_data["soma"], scale=1.0),
        partial(synth_axon, params=exemplar_data["ais"]["popt"], scale=1.0),
    ]

    with open(final_path) as f:
        final = json.load(f)
    # Copy so the per-variant updates below do not mutate the loaded ``final`` dict.
    params = dict(final["simplest_0"]["params"])
    gbar_it2_soma = params["gcabar_it2.somatic"]
    gbar_it2_basal = params["gcabar_it2.basal"]
    shift_it2 = params["shift_it2.somadend"]

    # Row 0: reference emodel with unchanged parameters.
    df = pd.DataFrame()
    df.loc[0, "name"] = "simplest"
    df.loc[0, "emodel"] = "simplest_0"
    df.loc[0, "new_parameters"] = json.dumps({})

    # Rows 1..len(fracs): "it2_low" gets ``frac`` of the original conductance and
    # "it2_high" gets the rest.
    shift = 3.0
    fracs = np.linspace(0.0, 1.0, 7)
    for i, frac in enumerate(fracs):
        df.loc[i + 1, "name"] = f"split_{i}"
        df.loc[i + 1, "emodel"] = "simplest_split"
        params.update(
            {
                "gcabar_it2_low.somatic": frac * gbar_it2_soma,
                "gcabar_it2_low.basal": frac * gbar_it2_basal,
                "shift_it2_low.somadend": shift_it2 + shift,
                "gcabar_it2.somatic": 0.0,
                "gcabar_it2.basal": 0.0,
                "shift_it2.somadend": shift_it2,
                "gcabar_it2_high.somatic": (1 - frac) * gbar_it2_soma,
                "gcabar_it2_high.basal": (1 - frac) * gbar_it2_basal,
                "shift_it2_high.somadend": shift_it2 - shift,
            }
        )
        df.loc[i + 1, "new_parameters"] = json.dumps(params)
    print(df)

    Path("traces").mkdir(exist_ok=True)
    # Reuse caches the evaluation result in eval_split.csv.
    with Reuse("eval_split.csv") as reuse:
        df = reuse(
            feature_evaluation,
            df,
            access_point,
            parallel_factory=parallel_factory,
            trace_data_path="traces",
            timeout=10000000,
            # record_ions_and_currents=True,
        )
    for i in df.index:
        plot_traces(df.loc[[i]], pdf_filename=f"traces_{i}.pdf")
    plt.close("all")

    print(df)
    d = get_feature_df(df)
    print(d)

    # Keep features whose protocol prefix (text before the first ".") has exactly
    # four "_"-separated tokens; presumably the holding-voltage scan protocols.
    columns = [col for col in d.columns if len(col.split(".")[0].split("_")) == 4]

    cmap = plt.get_cmap("coolwarm")
    colors = cmap(np.linspace(0, 1, len(d) - 1))
    plt.figure(figsize=(6, 3))
    # Reversed order so the reference row (index 0, colour "C2") is drawn last.
    for gid, c in zip(d.index[::-1], colors[::-1].tolist() + ["C2"]):
        holding_voltage, burst_number = [], []
        for col in columns:
            if col.endswith(".voltage_base"):
                # Matching burst-count feature of the same protocol/location.
                burst_col = col.rsplit(".", 1)[0] + ".all_burst_number"
                burst_number.append(d.loc[gid, burst_col])
                holding_voltage.append(d.loc[gid, col])
        # Reference model is drawn as markers only.
        ls = "" if gid == 0 else "-"
        plt.plot(holding_voltage, burst_number, marker="+", c=c, ls=ls)
    plt.xlabel("holding_voltage")
    plt.ylabel("burst_number")
    plt.axis([-80, -50, 0, 37])
    # NOTE(review): reference holding voltage (presumably mV) with a +/-3 band;
    # the origin of these constants is not documented here.
    ref_voltage = -65.3 - 2.959295568009344
    plt.axvline(ref_voltage, c="k", ls="--")
    plt.axvline(ref_voltage - 3, c="k", ls="-")
    plt.axvline(ref_voltage + 3, ls="-", c="k")
    norm = mpl.colors.Normalize(vmin=fracs[0], vmax=fracs[-1])
    plt.colorbar(
        mpl.cm.ScalarMappable(norm=norm, cmap=cmap),
        ax=plt.gca(),
        orientation="vertical",
        shrink=0.6,
        label="it2 left/right fraction",
    )
    plt.tight_layout()
    plt.savefig("shift_burst_scan.pdf")
    plt.close()
